import torch, functools
import torch.nn as nn
import torch_scatter
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing, radius_graph
from torch_scatter import scatter_add
from torch_scatter import scatter

# Size of the `nn.Embedding` table and output width of `linear_sc_emb` in
# GVPNet (neither module is referenced in GVPNet.forward).
_NUM_ATOM_TYPES = 9
# Default node embedding dims, as (n_scalar, n_vector).
_DEFAULT_V_DIM = (128, 16)
# Default edge embedding dims, as (n_scalar, n_vector).
_DEFAULT_E_DIM = (64, 1)
# Number of classes used to one-hot encode `batch.ss_x`
# (presumably secondary-structure labels -- confirm against the dataset).
num_ss_cls = 8

def _normalize(tensor, dim=-1):
    '''
    Normalizes a `torch.Tensor` along dimension `dim` without `nan`s.
    '''
    return torch.nan_to_num(
        torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True)))

def _rbf(D, D_min=0., D_max=20., D_count=16, device='cpu'):
    '''
    From https://github.com/jingraham/neurips19-graph-protein-design
    
    Returns an RBF embedding of `torch.Tensor` `D` along a new axis=-1.
    That is, if `D` has shape [...dims], then the returned tensor will have
    shape [...dims, D_count].
    '''
    D_mu = torch.linspace(D_min, D_max, D_count, device=device)
    D_mu = D_mu.view([1, -1])
    D_sigma = (D_max - D_min) / D_count
    D_expand = torch.unsqueeze(D, -1)

    RBF = torch.exp(-((D_expand - D_mu) / D_sigma) ** 2)
    return RBF

def _edge_features(coords, edge_index, D_max=4.5, num_rbf=16, device='cpu'):
    '''
    Builds (scalar, vector) edge features from node coordinates.

    :param coords: node coordinates indexed by the entries of `edge_index`
    :param edge_index: [2, n_edges] index tensor; the edge vector is
                       coords[edge_index[0]] - coords[edge_index[1]]
    :param D_max: upper end of the RBF grid for edge lengths
    :param num_rbf: number of radial basis functions
    :param device: device on which the RBF centres are created
    :return: tuple (edge_s, edge_v): the RBF expansion of each edge length,
             and the unit edge vector with an added channel axis at -2
    '''
    src, dst = edge_index[0], edge_index[1]
    E_vectors = coords[src] - coords[dst]

    lengths = E_vectors.norm(dim=-1)
    edge_s = _rbf(lengths, D_max=D_max, D_count=num_rbf, device=device)
    edge_v = _normalize(E_vectors).unsqueeze(-2)

    # Defensive: replace any nan/inf (e.g. from degenerate coordinates).
    return torch.nan_to_num(edge_s), torch.nan_to_num(edge_v)

def tuple_sum(*args):
    '''
    Adds any number of (s, V) tuples component-wise and returns
    a tuple of the per-position sums.
    '''
    return tuple(sum(parts) for parts in zip(*args))

def tuple_cat(*args, dim=-1):
    '''
    Concatenates (s, V) tuples component-wise.

    :param dim: concatenation axis expressed relative to the scalar tensors.
                It is first made non-negative using the scalar rank, and the
                same index is used for the vector tensors. Because those
                carry an extra trailing xyz axis, `dim=-1` concatenates
                scalar channels and vector channels (axis -2) together.
    '''
    axis = dim % args[0][0].dim()
    scalar_parts, vector_parts = zip(*args)
    return torch.cat(scalar_parts, dim=axis), torch.cat(vector_parts, dim=axis)

def tuple_index(x, idx):
    '''
    Applies the same index `idx` to both members of an (s, V) tuple,
    i.e. selects along their first dimension.

    :param idx: anything accepted by `torch.Tensor.__getitem__`
    '''
    scalars, vectors = x[0], x[1]
    return scalars[idx], vectors[idx]

def randn(n, dims, device="cpu"):
    '''
    Samples a random (s, V) tuple with standard-normal entries.

    :param n: number of data points
    :param dims: (n_scalar, n_vector)
    :return: s with shape (n, n_scalar) and V with shape (n, n_vector, 3);
             s is drawn first, then V
    '''
    n_scalar, n_vector = dims[0], dims[1]
    s = torch.randn(n, n_scalar, device=device)
    v = torch.randn(n, n_vector, 3, device=device)
    return s, v

def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True):
    '''
    L2 norm of tensor clamped above a minimum value `eps`.
    
    :param sqrt: if `False`, returns the square of the L2 norm
    '''
    out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps)
    return torch.sqrt(out) if sqrt else out

def _split(x, nv):
    '''
    Splits a merged representation of (s, V) back into a tuple. 
    Should be used only with `_merge(s, V)` and only if the tuple 
    representation cannot be used.
    
    :param x: the `torch.Tensor` returned from `_merge`
    :param nv: the number of vector channels in the input to `_merge`
    '''
    v = torch.reshape(x[..., -3*nv:], x.shape[:-1] + (nv, 3))
    s = x[..., :-3*nv]
    return s, v

def _merge(s, v):
    '''
    Merges a tuple (s, V) into a single `torch.Tensor`, where the
    vector channels are flattened and appended to the scalar channels.
    Should be used only if the tuple representation cannot be used.
    Use `_split(x, nv)` to reverse.
    '''
    v = torch.reshape(v, v.shape[:-2] + (3*v.shape[-2],))
    return torch.cat([s, v], -1)


class GVPNet(nn.Module):
    '''
    GVP-GNN with vector-gated GVPs over atom-level structure graphs
    (`torch_geometric.data.Batch`). Secondary-structure-level (SS)
    information is optional. By default it returns one prediction of size
    `out_channels` per graph, squeezed on the last dim (so a scalar per
    graph when `out_channels == 1`).

    Three modes are selected by `SS` / `SS_add`, which are mutually exclusive:

    * neither: node features `batch.side_chain_embs` are refined by the
      atom-level GVPConvLayers, projected to scalars, mean-pooled per graph
      and passed through a dense head.
    * `SS_add=True`: an embedding of the one-hot `batch.ss_x` labels is
      gathered onto atom-level nodes via `batch.mapping_a_to_b`. When
      `geo=True`, the 9 frame features `batch.b_frame_R_ts` are added to the
      labels first. The embedding is blended into the node embeddings with
      learned sigmoid weights before the atom-level layers run.
    * `SS=True`: after the atom-level layers, node embeddings are mean-pooled
      onto SS-level nodes and blended with the SS embedding. They are then
      refined by `num_ss` GVPConvLayers over `batch.ch_b_edge_index` and
      pooled per graph; per-graph SS node counts come from
      `batch.num_nodes_b`.

    :param num_rbf: number of radial bases to use in the edge embedding
    :param num_blocks: total number of GVPConvLayers. `num_blocks - num_ss`
                       are atom-level; `num_ss` are SS-level and are built
                       only if `SS` or `SS_add` is set.
    :param out_channels: output size of the final dense head
    :param cutoff: radius for the atom-level `radius_graph`; if below 4,
                   the precomputed `batch.edge_index` is used instead
    :param dropout: dropout probability inside the dense head
    :param side_chain_embs_dim: number of scalar input features per node
    :param max_num_neighbors: neighbor cap passed to `radius_graph`
    :param SS: enable the SS-level message-passing mode
    :param geo: append 9 frame features to the SS one-hot features
    :param SS_add: enable the additive SS-embedding mode
    :param num_ss: number of SS-level GVPConvLayers
    '''
    def __init__(self, num_rbf=16, num_blocks=5,
                 out_channels=1,
                 cutoff=4,
                 dropout=0.1,
                 side_chain_embs_dim=11,
                 max_num_neighbors=32,
                 SS=False,
                 geo=False,
                 SS_add=False,
                 num_ss=1):

        super().__init__()
        activations = (F.relu, None)
        self.SS = SS
        self.geo = geo
        self.SS_add = SS_add
        self.num_ss = num_ss
        self.cutoff = cutoff
        self.dropout = dropout
        self.num_blocks = num_blocks
        self.out_channels = out_channels
        self.side_chain_embs_dim = side_chain_embs_dim
        # NOTE: `linear_sc_emb` and `embed` are not used in `forward`; they
        # are kept so existing checkpoints / state_dicts still load.
        self.linear_sc_emb = nn.Linear(side_chain_embs_dim, _NUM_ATOM_TYPES)

        self.num_rbf = num_rbf
        self.embed = nn.Embedding(_NUM_ATOM_TYPES, _NUM_ATOM_TYPES)
        self.max_num_neighbors = max_num_neighbors

        # Atom-level edge embedding: (num_rbf scalars, 1 unit vector) -> E dims.
        self.W_a_e = nn.Sequential(
            LayerNorm((num_rbf, 1)),
            GVP((num_rbf, 1), _DEFAULT_E_DIM,
                activations=(None, None), vector_gate=True)
        )

        # Atom-level node embedding from scalar-only input features.
        self.W_a_v = nn.Sequential(
            LayerNorm((self.side_chain_embs_dim, 0)),
            GVP((self.side_chain_embs_dim, 0), _DEFAULT_V_DIM,
                activations=(None, None), vector_gate=True)
        )

        self.layers_a = nn.ModuleList(
                GVPConvLayer(_DEFAULT_V_DIM, _DEFAULT_E_DIM,
                             activations=activations, vector_gate=True)
            for _ in range(self.num_blocks-self.num_ss))

        ns, _ = _DEFAULT_V_DIM
        self.W_a_out = nn.Sequential(
            LayerNorm(_DEFAULT_V_DIM),
            GVP(_DEFAULT_V_DIM, (ns, 0),
                activations=activations, vector_gate=True)
        )

        if SS or SS_add:

            # Learned mixing logits (sigmoid-ed in forward) for the scalar
            # (logit0) and vector (logit1) channels.
            self.logit0 = nn.Parameter(torch.tensor(0.0))
            self.logit1 = nn.Parameter(torch.tensor(0.0))

            if geo:
                add_dim = 9
            else:
                add_dim = 0
            self.W_b_v = nn.Sequential(
                LayerNorm((num_ss_cls+add_dim, 0)),
                GVP((num_ss_cls+add_dim, 0), _DEFAULT_V_DIM,
                    activations=(None, None), vector_gate=True)
            )
            self.W_b_e = nn.Sequential(
                LayerNorm((num_rbf, 1)),
                GVP((num_rbf, 1), _DEFAULT_E_DIM,
                    activations=(None, None), vector_gate=True)
            )
            self.layers_b = nn.ModuleList(
                GVPConvLayer(_DEFAULT_V_DIM, _DEFAULT_E_DIM,
                                activations=activations, vector_gate=True)
                for _ in range(self.num_ss))
            self.W_b_out = nn.Sequential(
                LayerNorm(_DEFAULT_V_DIM),
                GVP(_DEFAULT_V_DIM, (ns, 0),
                    activations=activations, vector_gate=True)
            )

        self.dense = nn.Sequential(
            nn.Linear(ns, 2*ns), nn.ReLU(inplace=True),
            nn.Dropout(p=self.dropout),
            nn.Linear(2*ns, self.out_channels)
        )

        # Running counters of edges seen by forward (diagnostics only).
        self.num_edges_ca = 0
        self.num_edges_ss = 0

    def forward(self, batch, scatter_mean=True, dense=True):
        '''
        Forward pass which can be adjusted based on task formulation.

        :param batch: `torch_geometric.data.Batch` with `side_chain_embs`,
                      `coords_a_ca`, `edge_index` and `batch`, plus the SS
                      attributes (`ss_x`, `mapping_a_to_b`, ...) when `SS`
                      or `SS_add` is set
        :param scatter_mean: if `True`, returns the mean of the final node
                             embeddings for each graph; else, returns the
                             per-node embeddings
        :param dense: if `True`, applies the final dense head and squeezes
                      the last dim; else, returns the embedding
        :raises ValueError: if both `SS` and `SS_add` are set
        '''
        # Fail fast, before any counters are touched.
        if self.SS and self.SS_add:
            raise ValueError("SS and SS_add cannot be True at the same time")

        edge_index = batch.edge_index
        self.num_edges_ss += edge_index.shape[1]

        pos_a_ca = batch.coords_a_ca
        device = pos_a_ca.device
        if self.cutoff < 4:
            # NOTE(review): below a cutoff of 4 the precomputed graph is used
            # instead of a radius graph -- confirm this threshold is intended.
            edge_index_a = edge_index
        else:
            # Pass the node-to-graph assignment so that nodes belonging to
            # different graphs of the mini-batch are never connected.
            edge_index_a = radius_graph(pos_a_ca, r=self.cutoff,
                                        batch=getattr(batch, 'batch', None),
                                        max_num_neighbors=self.max_num_neighbors)
        self.num_edges_ca += edge_index_a.shape[1]
        h_a_E = _edge_features(pos_a_ca, edge_index_a,
                               num_rbf=self.num_rbf, device=device)

        h_a_V = self.W_a_v(batch.side_chain_embs)
        h_a_E = self.W_a_e(h_a_E)

        if self.SS_add:
            h_a_V = self._add_ss_embedding(batch, h_a_V)

        # The atom-level layers run in every mode.
        for layer in self.layers_a:
            h_a_V = layer(h_a_V, edge_index_a, h_a_E)

        if self.SS:
            out, batch_id = self._ss_level_forward(batch, h_a_V, device)
        else:
            out = self.W_a_out(h_a_V)
            batch_id = batch.batch

        if scatter_mean:
            out = torch_scatter.scatter_mean(out, batch_id, dim=0)
        if dense:
            out = self.dense(out).squeeze(-1)
        return out

    def _ss_node_features(self, batch):
        '''Returns SS node inputs: one-hot `ss_x`, plus 9 frame features if `geo`.'''
        z_b = batch.ss_x.long()
        # reshape (not squeeze) keeps the result 2-D even for one SS node.
        ss_x = F.one_hot(z_b, num_classes=num_ss_cls).float().reshape(-1, num_ss_cls)
        if self.geo:
            b_frame_R_ts = batch.b_frame_R_ts.reshape(-1, 9)
            ss_x = torch.cat([ss_x, b_frame_R_ts], dim=1)
        return ss_x

    def _add_ss_embedding(self, batch, h_a_V):
        '''Blends the SS embedding, gathered per atom-level node, into `h_a_V`.'''
        h_b_V = self.W_b_v(self._ss_node_features(batch))
        mapping_a_to_b = batch.mapping_a_to_b
        w0 = torch.sigmoid(self.logit0)
        w1 = torch.sigmoid(self.logit1)
        return (w0 * h_b_V[0][mapping_a_to_b] + (1 - w0) * h_a_V[0],
                w1 * h_b_V[1][mapping_a_to_b] + (1 - w1) * h_a_V[1])

    def _ss_level_forward(self, batch, h_a_V, device):
        '''Pools atom-level embeddings onto SS nodes, runs the SS-level layers,
        and returns (per-SS-node output, SS-node-to-graph index).'''
        n_b = len(batch.ss_x)
        mapping_a_to_b = batch.mapping_a_to_b
        pooled_s = scatter(h_a_V[0], mapping_a_to_b, dim=0, dim_size=n_b, reduce='mean')
        pooled_v = scatter(h_a_V[1], mapping_a_to_b, dim=0, dim_size=n_b, reduce='mean')

        h_b_V = self.W_b_v(self._ss_node_features(batch))
        w0 = torch.sigmoid(self.logit0)
        w1 = torch.sigmoid(self.logit1)
        h_b_V = (w0 * h_b_V[0] + (1 - w0) * pooled_s,
                 w1 * h_b_V[1] + (1 - w1) * pooled_v)

        edge_index_b = batch.ch_b_edge_index
        # NOTE(review): SS-level edges are counted in `num_edges_ca`, not
        # `num_edges_ss` -- kept as-is; confirm which counter is intended.
        self.num_edges_ca += edge_index_b.shape[1]
        h_b_E = self.W_b_e(_edge_features(batch.coords_b_, edge_index_b,
                                          num_rbf=self.num_rbf, device=device))

        for layer in self.layers_b:
            h_b_V = layer(h_b_V, edge_index_b, h_b_E)
        out = self.W_b_out(h_b_V)

        # Graph i owns the next num_nodes_b[i] SS nodes: [0,0,..,1,1,..,B-1].
        counts = torch.as_tensor(batch.num_nodes_b, device=device).long().view(-1)
        batch_id = torch.repeat_interleave(
            torch.arange(counts.numel(), device=device), counts)
        return out, batch_id
    

class _VDropout(nn.Module):
    '''
    Vector channel dropout where the elements of each
    vector channel are dropped together.
    '''
    def __init__(self, drop_rate):
        super(_VDropout, self).__init__()
        self.drop_rate = drop_rate
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, x):
        '''
        :param x: `torch.Tensor` corresponding to vector channels
        '''
        device = self.dummy_param.device
        if not self.training:
            return x
        mask = torch.bernoulli(
            (1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device)
        ).unsqueeze(-1)
        x = mask * x / (1 - self.drop_rate)
        return x
    
class Dropout(nn.Module):
    '''
    Combined dropout for tuples (s, V).
    Applies standard dropout to the scalar channels and whole-vector
    dropout (`_VDropout`) to the vector channels, with the same rate.
    '''
    def __init__(self, drop_rate):
        super(Dropout, self).__init__()
        self.sdropout = nn.Dropout(drop_rate)
        self.vdropout = _VDropout(drop_rate)

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or single `torch.Tensor`
                  (will be assumed to be scalar channels)
        :return: same structure as `x`, with dropout applied
        '''
        # isinstance (not `type(...) is`) so tensor subclasses such as
        # nn.Parameter are treated as scalar channels instead of being
        # unpacked along their first dimension.
        if isinstance(x, torch.Tensor):
            return self.sdropout(x)
        s, v = x
        return self.sdropout(s), self.vdropout(v)
    
class LayerNorm(nn.Module):
    '''
    Layer normalization for (s, V) tuples.
    Scalars go through a standard `nn.LayerNorm`. Vectors are divided by
    the root-mean-square of their channel norms. If the vector dimension
    is 0, the input is a plain scalar tensor.
    '''
    def __init__(self, dims):
        super().__init__()
        self.s, self.v = dims
        self.scalar_norm = nn.LayerNorm(self.s)

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or a single `torch.Tensor` of scalar channels when
                  there are no vector channels
        '''
        if self.v:
            s, v = x
            squared = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False)
            rms = squared.mean(dim=-2, keepdim=True).sqrt()
            return self.scalar_norm(s), v / rms
        return self.scalar_norm(x)
    
class GVP(nn.Module):
    '''
    Geometric Vector Perceptron: maps an (s, V) tuple to a new (s, V) tuple.
    Scalars see the norms of a linear mix of the input vectors. Output
    vectors are a linear mix of that same hidden mix, optionally gated.

    :param in_dims: tuple (n_scalar, n_vector)
    :param out_dims: tuple (n_scalar, n_vector)
    :param h_dim: intermediate number of vector channels; defaults to
                  max(n_vector_in, n_vector_out)
    :param activations: tuple of functions (scalar_act, vector_act)
    :param vector_gate: if `True`, vectors are gated by a sigmoid of a linear
                        map of the output scalars (vector_act, if given, is
                        applied to those scalars before the linear map)
    '''
    def __init__(self, in_dims, out_dims, h_dim=None,
                 activations=(F.relu, torch.sigmoid), vector_gate=False):
        super().__init__()
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.vector_gate = vector_gate
        if not self.vi:
            self.ws = nn.Linear(self.si, self.so)
        else:
            # Creation order matters for reproducible random init.
            self.h_dim = h_dim or max(self.vi, self.vo)
            self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
            self.ws = nn.Linear(self.h_dim + self.si, self.so)
            if self.vo:
                self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
                if self.vector_gate:
                    self.wsv = nn.Linear(self.so, self.vo)

        self.scalar_act, self.vector_act = activations
        # Zero-size parameter used only to find the module's current device.
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or (if there are no input vectors) a single `torch.Tensor`
        :return: tuple (s, V) of `torch.Tensor`,
                 or (if there are no output vectors) a single `torch.Tensor`
        '''
        if not self.vi:
            s = self.ws(x)
            v = None
            if self.vo:
                v = torch.zeros(s.shape[0], self.vo, 3,
                                device=self.dummy_param.device)
        else:
            s, v = x
            # Mix over the channel axis: (..., 3, vi) -> (..., 3, h_dim).
            vh = self.wh(v.transpose(-1, -2))
            s = self.ws(torch.cat([s, _norm_no_nan(vh, axis=-2)], -1))
            if self.vo:
                v = self.wv(vh).transpose(-1, -2)
                if self.vector_gate:
                    gate_input = self.vector_act(s) if self.vector_act else s
                    v = v * torch.sigmoid(self.wsv(gate_input)).unsqueeze(-1)
                elif self.vector_act:
                    v = v * self.vector_act(
                        _norm_no_nan(v, axis=-1, keepdims=True))

        if self.scalar_act:
            s = self.scalar_act(s)
        return (s, v) if self.vo else s
    
class GVPConv(MessagePassing):
    '''
    GVP message passing. Each edge message is computed by a stack of GVPs
    from the concatenation (source node, edge, target node), and messages
    are aggregated per target node. Returns the aggregated messages only;
    the residual update and feedforward live in `GVPConvLayer`.

    :param in_dims: input node embedding dimensions (n_scalar, n_vector)
    :param out_dims: output node embedding dimensions (n_scalar, n_vector)
    :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
    :param n_layers: number of GVPs in the message function
    :param module_list: preconstructed message function; if non-empty it is
                        used as-is and `n_layers` is ignored
    :param aggr: should be "add" if some incoming edges are masked, as in
                 a masked autoregressive decoder architecture, otherwise "mean"
    :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
    :param vector_gate: whether to use vector gating
                        (vector_act will be used as sigma^+ in vector gating if `True`)
    '''
    def __init__(self, in_dims, out_dims, edge_dims,
                 n_layers=3, module_list=None, aggr="mean",
                 activations=(F.relu, torch.sigmoid), vector_gate=False):
        super(GVPConv, self).__init__(aggr=aggr)
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.se, self.ve = edge_dims

        make_gvp = functools.partial(GVP,
                activations=activations, vector_gate=vector_gate)
        # A message concatenates source node, edge and target node features.
        msg_in_dims = (2*self.si + self.se, 2*self.vi + self.ve)

        layers = module_list or []
        if not layers:
            if n_layers == 1:
                layers.append(make_gvp(msg_in_dims, (self.so, self.vo),
                                       activations=(None, None)))
            else:
                layers.append(make_gvp(msg_in_dims, out_dims))
                layers.extend(make_gvp(out_dims, out_dims)
                              for _ in range(n_layers - 2))
                # The last GVP is linear (no activations).
                layers.append(make_gvp(out_dims, out_dims,
                                       activations=(None, None)))
        self.message_func = nn.Sequential(*layers)

    def forward(self, x, edge_index, edge_attr):
        '''
        :param x: tuple (s, V) of `torch.Tensor`
        :param edge_index: array of shape [2, n_edges]
        :param edge_attr: tuple (s, V) of `torch.Tensor`
        :return: aggregated messages as an (s, V) tuple
        '''
        node_s, node_v = x
        # propagate() needs plain 2-D node features, so vectors are flattened.
        flat_v = node_v.reshape(node_v.shape[0], 3 * node_v.shape[1])
        aggregated = self.propagate(edge_index, s=node_s, v=flat_v,
                                    edge_attr=edge_attr)
        return _split(aggregated, self.vo)

    def message(self, s_i, v_i, s_j, v_j, edge_attr):
        # Argument names are introspected by MessagePassing (_i: target, _j: source).
        v_j = v_j.view(v_j.shape[0], v_j.shape[1] // 3, 3)
        v_i = v_i.view(v_i.shape[0], v_i.shape[1] // 3, 3)
        per_edge = self.message_func(tuple_cat((s_j, v_j), edge_attr, (s_i, v_i)))
        return _merge(*per_edge)
    
class GVPConvLayer(nn.Module):
    '''
    Full graph convolution / message passing layer with 
    Geometric Vector Perceptrons. Residually updates node embeddings with
    aggregated incoming messages, applies a pointwise feedforward 
    network to node embeddings, and returns updated node embeddings.
    Each residual step is dropout -> add -> LayerNorm.
    
    To only compute the aggregated messages, see `GVPConv`.
    
    :param node_dims: node embedding dimensions (n_scalar, n_vector)
    :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
    :param n_message: number of GVPs to use in message function
    :param n_feedforward: number of GVPs to use in feedforward function
    :param drop_rate: drop probability in all dropout layers
    :param autoregressive: if `True`, this `GVPConvLayer` will be used
           with a different set of input node embeddings for messages
           where src >= dst
    :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
    :param vector_gate: whether to use vector gating.
                        (vector_act will be used as sigma^+ in vector gating if `True`)
    '''
    def __init__(self, node_dims, edge_dims,
                 n_message=3, n_feedforward=2, drop_rate=.1,
                 autoregressive=False, 
                 activations=(F.relu, torch.sigmoid), vector_gate=False):
        
        super(GVPConvLayer, self).__init__()
        # "add" aggregation in autoregressive mode: forward() then divides by
        # the per-node in-degree itself.
        self.conv = GVPConv(node_dims, node_dims, edge_dims, n_message,
                           aggr="add" if autoregressive else "mean",
                           activations=activations, vector_gate=vector_gate)
        GVP_ = functools.partial(GVP, 
                activations=activations, vector_gate=vector_gate)
        # Index 0: message residual; index 1: feedforward residual.
        self.norm = nn.ModuleList([LayerNorm(node_dims) for _ in range(2)])
        self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])

        ff_func = []
        if n_feedforward == 1:
            ff_func.append(GVP_(node_dims, node_dims, activations=(None, None)))
        else:
            # Hidden width: 4x scalar channels, 2x vector channels.
            hid_dims = 4*node_dims[0], 2*node_dims[1]
            ff_func.append(GVP_(node_dims, hid_dims))
            for i in range(n_feedforward-2):
                ff_func.append(GVP_(hid_dims, hid_dims))
            # The last feedforward GVP is linear (no activations).
            ff_func.append(GVP_(hid_dims, node_dims, activations=(None, None)))
        self.ff_func = nn.Sequential(*ff_func)

    def forward(self, x, edge_index, edge_attr,
                autoregressive_x=None, node_mask=None):
        '''
        :param x: tuple (s, V) of `torch.Tensor`
        :param edge_index: array of shape [2, n_edges]
        :param edge_attr: tuple (s, V) of `torch.Tensor`
        :param autoregressive_x: tuple (s, V) of `torch.Tensor`. 
                If not `None`, will be used as src node embeddings
                for forming messages where src >= dst. The current node 
                embeddings `x` will still be the base of the update and the 
                pointwise feedforward.
        :param node_mask: array of type `bool` to index into the first
                dim of node embeddings (s, V). If not `None`, only
                these nodes will be updated.
        :return: updated node embeddings as an (s, V) tuple
        '''
        
        if autoregressive_x is not None:
            # Edges with src < dst use the current embeddings `x`; the
            # remaining (src >= dst) edges use `autoregressive_x`.
            src, dst = edge_index
            mask = src < dst
            edge_index_forward = edge_index[:, mask]
            edge_index_backward = edge_index[:, ~mask]
            edge_attr_forward = tuple_index(edge_attr, mask)
            edge_attr_backward = tuple_index(edge_attr, ~mask)
            
            dh = tuple_sum(
                self.conv(x, edge_index_forward, edge_attr_forward),
                self.conv(autoregressive_x, edge_index_backward, edge_attr_backward)
            )
            
            # In-degree of each node (min 1, avoiding division by zero), used
            # to turn the summed messages into a mean.
            # NOTE(review): this assumes the layer was built with
            # autoregressive=True ("add" aggregation); with "mean" the
            # messages would be averaged twice.
            count = scatter_add(torch.ones_like(dst), dst,
                        dim_size=dh[0].size(0)).clamp(min=1).unsqueeze(-1)
            
            dh = dh[0] / count, dh[1] / count.unsqueeze(-1)

        else:
            dh = self.conv(x, edge_index, edge_attr)
        
        if node_mask is not None:
            # Keep a handle on the full (unmasked) input to write results into.
            x_ = x
            x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask)
            
        x = self.norm[0](tuple_sum(x, self.dropout[0](dh)))
        
        dh = self.ff_func(x)
        x = self.norm[1](tuple_sum(x, self.dropout[1](dh)))
        
        if node_mask is not None:
            # NOTE(review): this writes in place into the caller's input
            # tensors (x_ is the original `x`), so the input is mutated.
            x_[0][node_mask], x_[1][node_mask] = x[0], x[1]
            x = x_
        return x